{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Inference in Stan\n",
    "\n",
    "Variational inference is a scalable technique for approximate Bayesian inference.\n",
    "Stan implements an automatic variational inference algorithm\n",
    "called Automatic Differentiation Variational Inference (ADVI),\n",
    "which searches over a family of simple densities to find the best\n",
    "approximate posterior density.\n",
    "ADVI produces an estimate of the parameter means together with a sample\n",
    "from the approximate posterior density.\n",
    "\n",
    "ADVI maximizes the variational objective function, the evidence lower bound (ELBO),\n",
    "using stochastic gradient ascent.\n",
    "The algorithm ascends these gradients using an adaptive stepsize sequence\n",
    "that has one parameter, ``eta``, which is adjusted during warmup.\n",
    "The number of draws used to approximate the ELBO is denoted by ``elbo_samples``.\n",
    "ADVI heuristically determines a rolling window over which it computes\n",
    "the average and the median relative change of the ELBO.\n",
    "When either of these falls below a threshold, denoted by ``tol_rel_obj``,\n",
    "the algorithm is considered to have converged."
   ]
  },
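  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the ELBO mentioned above can be written in standard textbook notation as follows.\n",
    "The symbols $y$ (data), $\\theta$ (model parameters), $\\phi$ (variational parameters) and $S$ (number of draws)\n",
    "are introduced here for illustration only and are not CmdStan argument names.\n",
    "Stan's implementation works on an unconstrained parameter space, so its details may differ from this sketch.\n",
    "\n",
    "$$\n",
    "\\mathrm{ELBO}(\\phi) = \\mathbb{E}_{q(\\theta;\\phi)}\\left[\\log p(y, \\theta)\\right] - \\mathbb{E}_{q(\\theta;\\phi)}\\left[\\log q(\\theta;\\phi)\\right]\n",
    "$$\n",
    "\n",
    "A Monte Carlo estimate of this quantity averages over draws from the approximating density:\n",
    "\n",
    "$$\n",
    "\\widehat{\\mathrm{ELBO}}(\\phi) = \\frac{1}{S} \\sum_{s=1}^{S} \\left[\\log p(y, \\theta^{(s)}) - \\log q(\\theta^{(s)};\\phi)\\right], \\qquad \\theta^{(s)} \\sim q(\\theta;\\phi)\n",
    "$$\n",
    "\n",
    "In this sketch, $S$ plays the role of ``elbo_samples``: more draws give a less noisy estimate\n",
    "at a higher cost per evaluation."
   ]
  },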
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: variational inference for model ``bernoulli.stan``\n",
    "\n",
    "In CmdStanPy, the `CmdStanModel` method `variational` invokes CmdStan with\n",
    "`method=variational` and returns an estimate of the approximate posterior\n",
    "mean of all model parameters as well as a set of draws from this approximate\n",
    "posterior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from cmdstanpy import CmdStanModel, cmdstan_path\n",
    "\n",
    "bernoulli_dir = os.path.join(cmdstan_path(), 'examples', 'bernoulli')\n",
    "stan_file = os.path.join(bernoulli_dir, 'bernoulli.stan')\n",
    "data_file = os.path.join(bernoulli_dir, 'bernoulli.data.json')\n",
    "# instantiate, compile bernoulli model\n",
    "model = CmdStanModel(stan_file=stan_file)\n",
    "# run CmdStan's variational inference method, returns object `CmdStanVB`\n",
    "vi = model.variational(data=data_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class [`CmdStanVB`](https://cmdstanpy.readthedocs.io/en/latest/api.html#stanvariational) provides the following properties to access information about the parameter names, estimated means, and the sample:\n",
    " + `column_names`\n",
    " + `variational_params_dict`\n",
    " + `variational_params_np`\n",
    " + `variational_params_pd`\n",
    " + `variational_sample`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(vi.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(vi.variational_params_dict['theta'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(vi.variational_sample.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These estimates are only valid if the algorithm has converged to a good\n",
    "approximation. When the algorithm fails to converge, the `variational`\n",
    "method raises a `RuntimeError`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_fail = CmdStanModel(stan_file='eta_should_fail.stan')\n",
    "vi_fail = model_fail.variational()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To suppress this error and keep the results of a run that did not converge, set `require_converged=False`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vi_fail = model_fail.variational(require_converged=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This lets you inspect the output to help diagnose the problem with the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vi_fail.variational_params_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [API docs](https://cmdstanpy.readthedocs.io/en/latest/api.html), section [`CmdStanModel.variational`](https://cmdstanpy.readthedocs.io/en/latest/api.html#cmdstanpy.CmdStanModel.variational), for a full description of all arguments."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
